#  Copyright (c) 2021, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

from coremltools.converters.mil.mil import Builder as mb, types as types
from coremltools.converters.mil.mil.passes.pass_registry import register_pass
from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass

@register_pass(namespace="common")
class update_output_dtypes(AbstractGraphPass):
    """
    Make the dtypes of the main function's outputs agree with the dtypes
    listed in prog.main_output_types (populated from the "outputs" argument
    the user passes to the coremltools.convert() API).

    The entries of prog.main_output_types, when present, are matched to the
    output vars purely by position, so both sequences are assumed to be in
    the same order.
    """

    def apply(self, prog):
        """
        Append a cast after every main-function output whose dtype differs
        from the user-requested one, and register the cast results as the
        function's new outputs.

        Does nothing when prog.main_output_types is None or empty.

        Raises:
            ValueError: if the number of requested output types differs from
                the number of outputs of the main function.
        """
        requested_types = prog.main_output_types
        main_func = prog.functions["main"]
        current_outputs = main_func.outputs

        # Nothing was requested by the user: leave the program untouched.
        if not requested_types:
            return

        num_requested = len(requested_types)
        num_produced = len(current_outputs)
        if num_requested != num_produced:
            msg = "Number of outputs provided by the user, which is {}, " \
                  "does not match the number of outputs generated by the model, which is {}"
            raise ValueError(msg.format(num_requested, num_produced))

        updated_outputs = []
        # Lengths were verified above, so a positional pairing is exact.
        for requested, var in zip(requested_types, current_outputs):
            target_dtype = requested.dtype

            # A cast is inserted only when a dtype was requested, the output
            # is a tensor or a scalar, and its dtype is not already the
            # requested one.
            cast_needed = (
                target_dtype is not None
                and (types.is_tensor(var.sym_type) or types.is_scalar(var.sym_type))
                and not (target_dtype == var.dtype)
            )
            if not cast_needed:
                updated_outputs.append(var)
                continue

            # The original var gives up its name (it gets a dtype-tagged one)
            # so that the cast result can carry the original output name.
            # NOTE(review): if this var is also a function input, the rename
            # would presumably change the input's name too - confirm whether
            # that case needs a guard.
            public_name = var.name
            var.set_name("{}_type_{}".format(public_name, types.builtin_to_string(var.dtype)))
            with main_func:
                casted = mb.cast(x=var, dtype=types.builtin_to_string(target_dtype))
                casted.set_name(public_name)
            updated_outputs.append(casted)

        main_func.set_outputs(updated_outputs)





